"""
问题类
"""
import json
from collections import OrderedDict
from collections.abc import Iterable
from contextlib import contextmanager

import markdown

from core.my_sqlparser import get_type_table, build_format_sql, make_sql_key, \
    check_param_same_route

# Registry of question classes keyed by their ``tag``, in registration order.
registered_questions = OrderedDict()


def register_que(cls):
    """Class decorator: record ``cls`` in ``registered_questions`` under ``cls.tag``.

    A later class with the same tag replaces the earlier entry. ``cls`` is
    returned unchanged so this can be used as ``@register_que``.
    """
    registered_questions.update({cls.tag: cls})
    return cls


class Solution:
    """One automatic way of answering a question.

    The attributes are filled in by whoever registers the solution:

    * ``condition``   -- predicate called with the API model; the solution is
      only tried when it returns truthy. By default every model is accepted.
    * ``cal_flow``    -- callable taking the model and producing a raw result.
    * ``res_builder`` -- callable turning that raw result into the stored answer.
    """

    def __init__(self):
        self.condition = lambda _model: True  # accept every model by default
        self.cal_flow = None
        self.res_builder = None


class QueBase:
    """Base class for every question type.

    A question asks for one fact about an API model. Concrete subclasses are
    registered with ``register_que`` and provide:

    * ``tag``      -- identifier stored as ``que_type`` on questions and states;
    * ``comment``  -- human-readable description of the question;
    * ``solution`` -- the subclass's own list of ``Solution`` objects used for
      automatic solving (``None`` here, so each subclass must define it);
    * ``update_model`` / ``_re_init_model`` -- write / reset the model fields
      the question is about.
    """
    tag = ""
    comment = ""
    solution = None

    @classmethod
    def re_init_model(cls, url):
        """Reset this question's fields on the APIModel with ``url`` and save it.

        Does nothing when no model has that URL.
        """
        from core.documents import APIModel
        model = APIModel.objects(url=url).first()
        if model:
            cls._re_init_model(model)
            model.save()

    @classmethod
    def _re_init_model(cls, model):
        """Reset the model fields owned by this question type (without saving).

        :raises ValueError: always in the base class; subclasses must override.
        """
        raise ValueError

    @classmethod
    @contextmanager
    def register(cls):
        """Register an automatic solution for this question type.

        Marks the type's QuestionState as ``auto_solve_available`` and saves
        it, yields a fresh ``Solution`` to be configured in the ``with`` body,
        then appends that solution to ``cls.solution``::

            with SomeQuestion.register() as s:
                s.condition = ...
                s.cal_flow = ...
                s.res_builder = ...

        If the ``with`` body raises, the solution is not appended, although
        the state has already been saved as available.
        """
        from .documents import QuestionState
        state = QuestionState.get_state(cls.tag)
        state.auto_solve_available = True
        state.save()
        s = Solution()
        yield s
        cls.solution.append(s)

    @classmethod
    def get_comment(cls):
        """Return the human-readable description of this question type."""
        return cls.comment

    @classmethod
    def update_model(cls, model, answer):
        """Apply the (string) ``answer`` to ``model``.

        :raises ValueError: always in the base class; subclasses must override.
        """
        raise ValueError

    @classmethod
    def get_state(cls):
        """Return the QuestionState for ``cls.tag``, creating and saving one if missing."""
        from core.documents import QuestionState
        state = QuestionState.objects(que_type=cls.tag).first()
        if not state:
            state = QuestionState(que_type=cls.tag)
            state.save()
        return state

    @classmethod
    def can_solve(cls):
        """Return True when at least one automatic solution is registered.

        Uses truthiness so the base class (``solution is None``) reports
        False instead of raising TypeError from ``len(None)``.
        """
        return bool(cls.solution)

    @classmethod
    def solve(cls, question):
        """Try to answer ``question`` automatically.

        Loads the APIModel whose ``url`` equals ``question.about``. Nothing
        happens if the model is missing or ignored. The solutions registered
        for ``question.que_type`` are tried in order, and the first one whose
        ``condition`` accepts the model is run:

        1. a TaskProgress is created;
        2. ``cal_flow(model)`` computes the raw result (iterables are
           materialised into a list);
        3. ``res_builder`` turns the result into ``question.result``;
        4. the question is saved as solved.

        ``model`` itself is not saved by this method.
        """
        from core.documents import APIModel, TaskProgress
        import random
        model = APIModel.objects(url=question.about).first()
        if not model or model.is_ignored():
            return
        for solv in registered_questions[question.que_type].solution:
            if not solv.condition(model):
                continue
            # Key the task by the question's own type: solve_problems calls
            # cls.solve for questions of every type, so cls.tag need not match
            # (it is "" when called on QueBase).
            model.lastTaskKey = "[{}]-{}-{}".format(
                question.que_type, model.relate_url, random.random())
            task = TaskProgress(key=model.lastTaskKey, name="自动解决一个问题")
            task.save()
            res = solv.cal_flow(model)
            # Materialise lazy results (e.g. generators) before building the answer.
            # NOTE(review): str/dict results are Iterable too and would become
            # lists of characters/keys here -- confirm cal_flow never returns those.
            if isinstance(res, Iterable):
                res = list(res)
            question.result = solv.res_builder(res)
            question.solved = True
            question.save()
            task.finish()
            break

    @staticmethod
    def get_unsolved_question(que_type=None, about=None):
        """Yield fresh, unsaved Question objects that are not solved yet.

        :param que_type: restrict to this question type; if falsy, every
            registered type is used.
        :param about: restrict to this API url; if falsy, every APIModel url
            is used.
        :return: generator of ``Question`` instances (``about``, ``que_type``
            and ``key`` set), one per (url, type) pair that has no solved
            ``Question`` with the same key.
        """
        from core.documents import Question, APIModel
        types = [que_type] if que_type else list(registered_questions)
        urls = [about] if about else (m.url for m in APIModel.objects())
        for url in urls:
            for q_type in types:
                question = Question()
                question.about = url
                question.que_type = q_type
                question.key = question.make_key()
                if Question.objects(key=question.key, solved=True).first():
                    continue
                yield question

    @classmethod
    def get_about_question(cls, model):
        """Return the stored Question of this type about ``model``.

        If none is stored, return a new unsaved one with its key set.
        """
        from core.documents import Question
        question = Question(que_type=cls.tag, about=model.url)
        question.key = question.make_key()
        t = Question.objects(key=question.key).first()
        if t:
            return t
        return question

    @classmethod
    def solve_problems(cls, url=None):
        """Automatically solve unsolved questions.

        Question types are visited in ``solvers.sol_order``. A type is skipped
        when its QuestionState is missing or not ``auto_solve_available``.

        :param url: only handle questions about this API url; every url when
            not given.
        :return: None
        """
        from core.documents import QuestionState
        from solvers import sol_order
        for order in sol_order:
            que_state = QuestionState.objects(que_type=order.tag).first()
            if not (que_state and que_state.auto_solve_available):
                continue
            # get_unsolved_question already treats a falsy ``about`` as "every url".
            for question in cls.get_unsolved_question(que_type=order.tag, about=url):
                cls.solve(question)


@register_que
class IsLogin(QueBase):
    """Question: does calling the API require a logged-in user?"""
    tag = 'IS_LOGIN'
    comment = "是否需要登陆"
    solution = []

    @classmethod
    def update_model(cls, model, answer):
        """Decode the JSON ``answer`` and store it as ``model.need_login``."""
        decoded = json.loads(answer)
        model.need_login = decoded

    @classmethod
    def _re_init_model(cls, model):
        """Forget the login requirement by setting ``need_login`` back to None."""
        model.need_login = None


@register_que
class HasGroup(QueBase):
    """Question: which permission groups may access the API?"""
    tag = 'HAS_GROUP'
    comment = "访问权限组"
    solution = []

    @classmethod
    def update_model(cls, model, answer):
        """Decode the JSON ``answer`` and store it as ``model.user_group``."""
        # Write the field this question owns (the one _re_init_model resets),
        # not ``need_login``, which belongs to IsLogin.
        model.user_group = json.loads(answer)

    @classmethod
    def _re_init_model(cls, model):
        """Clear the recorded permission groups."""
        model.user_group = []


@register_que
class ParamIsNecessary(QueBase):
    """Question: which request parameters are actually required?"""
    tag = 'PARAM_IS_NECESSARY'
    comment = "确认参数的必要性"
    solution = []

    @classmethod
    def update_model(cls, model, answer):
        """Mark every argument listed in the JSON ``answer`` as not required.

        Each entry is passed to ``model.args.get_sub_arg``; the matching
        argument is saved immediately.
        """
        optional_args = json.loads(answer)
        for arg_route in optional_args:
            sub_arg = model.args.get_sub_arg(arg_route)
            sub_arg.required = False
            sub_arg.save()

    @classmethod
    def _re_init_model(cls, model):
        """No model field is reset for this question type."""
        # NOTE(review): ``required`` flags cleared by update_model are not
        # restored here -- confirm that is intended.


@register_que
class ParamHasGroup(QueBase):
    """Question: are the parameters grouped, and how?"""
    tag = 'PARAM_HAS_GROUP'
    comment = "参数是否有分组，分组情况是怎么样的"
    solution = []

    @classmethod
    def update_model(cls, model, answer):
        """Apply the JSON ``answer``: a list of ``[arg_route, groups]`` pairs.

        For each pair, the argument found by ``model.args.get_sub_arg`` gets
        ``groups`` assigned and is saved.
        """
        pairs = json.loads(answer)
        for pair in pairs:
            sub_arg = model.args.get_sub_arg(pair[0])
            sub_arg.groups = pair[1]
            sub_arg.save()

    @classmethod
    def _re_init_model(cls, model):
        """Drop every parameter group recorded on ``model``."""
        from core import compute_unit
        compute_unit.remove_all_group(model)


@register_que
class ParamRange(QueBase):
    """Question: what range of values does each parameter accept?"""
    tag = 'PARAM_RANGE'
    comment = "各个参数的取值范围"
    solution = []

    @classmethod
    def update_model(cls, model, answer):
        """Apply the JSON ``answer``: a list of ``[arg_route, value_range]`` pairs.

        For each pair, the argument found by ``model.args.get_sub_arg`` gets
        ``value_range`` assigned and is saved.
        """
        pairs = json.loads(answer)
        for pair in pairs:
            sub_arg = model.args.get_sub_arg(pair[0])
            sub_arg.value_range = pair[1]
            sub_arg.save()

    @classmethod
    def _re_init_model(cls, model):
        """No model field is reset for this question type."""


@register_que
class AboutSQL(QueBase):
    """Question: which database tables / SQL statements does the API touch?"""
    tag = 'ABOUT_SQL'
    comment = "API影响到的数据库"
    solution = []

    @classmethod
    def update_model(cls, model, answer):
        """Rebuild ``model.aboutSql`` from the SQL captured while calling the API.

        ``answer`` is JSON: a list of ``[param, [sql, ...]]`` pairs, i.e. the
        request parameters that were sent and the SQL strings observed for them.

        Effects:
          * ``model.aboutSql`` is replaced by
            ``{table: {sql_type: [[sql, formatted_sql], ...]}}``, keeping one
            statement per normalised SQL key;
          * warnings returned by ``check_param_same_route`` get ``url`` set
            and are saved;
          * each argument route with candidate SQL positions gets
            ``value_range_type = RangeType.SQL_TABLE`` and a ``|``-joined
            ``value_range``, and is saved.
        ``model`` itself is not saved here.
        """
        import sqlparse
        from .constant import RangeType
        model.aboutSql = {}
        # Extract variables from each statement to build a normalised key and
        # a key -> variables map.
        sql_variable = {}
        data_sql = []  # [[{request params}, [(sql, key), ...]], ...]
        all_sql = []
        for param, sqls in json.loads(answer):
            data_sql.append([param, []])
            for sql in sqls:
                sql = sqlparse.parse(sql)[0]
                all_sql.append(sql)
                diff = []
                key = make_sql_key(sql, diff)
                if key not in sql_variable:
                    sql_variable[key] = diff
                data_sql[-1][1].append((sql, key))
        # Work out which SQL positions each parameter may correspond to.
        route_checkpoint = {}  # {route: [(sql_key, sql_route, last_identifier), ...]}
        warnings = check_param_same_route(data_sql, route_checkpoint, sql_variable)
        for warning in warnings:
            warning.url = model.url
            warning.save()
        for route, candidates in route_checkpoint.items():
            if not candidates:
                continue
            arg = model.args.get_sub_arg(route)
            arg.value_range_type = RangeType.SQL_TABLE
            # sorted(): iteration order of a set of str depends on per-process
            # hash randomisation, so the persisted value would vary between runs.
            arg.value_range = '|'.join(sorted(set(str(r[2]) for r in candidates if str(r[2]))))
            arg.save()
        checked = set()
        for sql in all_sql:
            key = make_sql_key(sql, [])
            if key in checked:
                continue  # record only one statement per normalised key
            checked.add(key)
            sql_type, tables = get_type_table(sql)
            for tb in tables:
                # Index through model.aboutSql on every access (rather than
                # keeping local references) exactly as before, in case the
                # field is a change-tracked container.
                if tb not in model.aboutSql:
                    model.aboutSql[tb] = {}
                if sql_type not in model.aboutSql[tb]:
                    model.aboutSql[tb][sql_type] = []
                re_sql_route = []
                for route in route_checkpoint:
                    for t_key, t_route, _ in route_checkpoint[route]:
                        if key == t_key:
                            re_sql_route.append(t_route)
                model.aboutSql[tb][sql_type].append(
                    [str(sql), build_format_sql(sql, re_sql_route)])

    @classmethod
    def _re_init_model(cls, model):
        """Forget every recorded table / statement."""
        model.aboutSql = {}


@register_que
class APIComment(QueBase):
    """Question: what is the API's description / comment?"""
    tag = 'API_APPLICATION'
    comment = "API的注释是什么"
    solution = []

    @classmethod
    def update_model(cls, model, answer):
        """Render the Markdown ``answer`` to HTML and store it as ``model.comment``."""
        # NOTE(review): Python-Markdown passes raw HTML through unchanged; if
        # answers can come from untrusted users, sanitise before display.
        rendered = markdown.markdown(answer)
        model.comment = rendered

    @classmethod
    def _re_init_model(cls, model):
        """Clear the stored comment."""
        model.comment = ""


@register_que
class ParamType(QueBase):
    """Question: what is the type of each parameter?"""
    tag = 'PARAM_TYPE'
    comment = "各个参数的类型"
    solution = []

    @classmethod
    def update_model(cls, model, answer):
        """Apply the JSON ``answer``: a list of ``[arg_route, val_type]`` pairs.

        For each pair, the argument found by ``model.args.get_sub_arg`` gets
        ``val_type`` assigned and is saved.
        """
        pairs = json.loads(answer)
        for pair in pairs:
            sub_arg = model.args.get_sub_arg(pair[0])
            sub_arg.val_type = pair[1]
            sub_arg.save()

    @classmethod
    def _re_init_model(cls, model):
        """No model field is reset for this question type."""
